import os
import sys
import gzip
import pysam
from Bio import SeqIO


assembly = 'hg38'

def read_chromosome_sizes(
    assembly,
    directory="/osc-fs_home/scratch/mdehoon/Data/Genomes",
):
    """Read chromosome names and lengths for a genome assembly.

    The table is read from ``<directory>/<assembly>/<assembly>.chrom.sizes``,
    a whitespace-separated text file with one ``name size`` pair per line.
    Entries whose name ends in ``_alt`` are skipped, as are blank lines.

    Args:
        assembly: Assembly name, e.g. ``"hg38"``.
        directory: Directory containing one sub-directory per assembly.
            Defaults to the location the script has always used.

    Returns:
        A ``(chromosomes, sizes)`` tuple of two parallel lists holding the
        chromosome names (str) and their lengths (int), in file order.

    Raises:
        OSError: If the chromosome sizes file cannot be opened.
        ValueError: If a non-blank line does not consist of exactly two
            fields, or if the size field is not an integer.
    """
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    chromosomes = []
    sizes = []
    # The context manager closes the file even if a malformed line raises.
    with open(path) as handle:
        for line in handle:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            chromosome, size = fields
            if chromosome.endswith("_alt"):
                continue
            chromosomes.append(chromosome)
            sizes.append(int(size))
    return chromosomes, sizes

def generate_unmapped_alignments():
    """Yield one unmapped alignment per sequence in the ``seqlist_<N>.fa`` files.

    The current working directory is scanned for files with extension
    ``.fa`` whose base name has the form ``<prefix>_<number>``.  The prefix
    must be ``seqlist`` and the numbers must run 0, 1, 2, ... without gaps.
    The files are then read in numerical order, and for every FASTA record
    an alignment is yielded that carries only the record's id as its query
    name and has the unmapped flag set.

    Yields:
        pysam.AlignedSegment: An unmapped alignment for each FASTA record.

    Raises:
        AssertionError: If a matching file has a prefix other than
            ``seqlist``, or if the file numbers are not consecutive from 0.
        ValueError: If the part after the underscore is not an integer.
    """
    numbered = []
    for filename in os.listdir("."):
        basename, extension = os.path.splitext(filename)
        if extension != ".fa":
            continue
        try:
            seqlist, number = basename.split("_")
        except ValueError:
            # Base name does not contain exactly one underscore
            # (e.g. "seqlist.fa" or "skipped.fa"); not one of our files.
            continue
        assert seqlist == "seqlist"
        numbered.append((int(number), filename))
    numbered.sort()
    # The numbering must be gapless and start at zero.
    for expected, (number, filename) in enumerate(numbered):
        assert number == expected
    for number, filename in numbered:
        # Parse from the handle we opened, so the file is opened only once
        # and is closed even if the consumer abandons the generator early.
        with open(filename) as stream:
            for record in SeqIO.parse(stream, "fasta"):
                alignment = pysam.AlignedSegment()
                alignment.query_name = record.id
                alignment.is_unmapped = True
                yield alignment

def generate_skipped_alignments():
    """Yield one unmapped, QC-failed alignment per sequence in ``skipped.fa``.

    The file ``skipped.fa`` is read from the current working directory.
    For every FASTA record an alignment is yielded that carries only the
    record's id as its query name and has both the unmapped flag and the
    QC-fail flag set.

    Yields:
        pysam.AlignedSegment: An unmapped, QC-failed alignment for each
        FASTA record.

    Raises:
        OSError: If ``skipped.fa`` cannot be opened (raised when the first
            item is requested, since this is a generator).
    """
    filename = "skipped.fa"
    # Parse from the handle we opened, so the file is opened only once and
    # is closed even if the consumer abandons the generator early.
    with open(filename) as stream:
        for record in SeqIO.parse(stream, "fasta"):
            alignment = pysam.AlignedSegment()
            alignment.query_name = record.id
            alignment.is_unmapped = True
            alignment.is_qcfail = True
            yield alignment

def get_genome_location(alignment):
    """Return a sortable, hashable-by-value description of where a read maps.

    Args:
        alignment: A mapped alignment exposing ``reference_name``,
            ``reference_start``, ``reference_end``, ``is_reverse`` and
            ``cigar`` (e.g. a ``pysam.AlignedSegment``).

    Returns:
        A tuple ``(chromosome, start, end, strand, cigar)`` where ``strand``
        is ``"-"`` for reverse-strand alignments and ``"+"`` otherwise.
        Two alignments with equal tuples are treated as the same location
        by the caller.

    Raises:
        AssertionError: If the alignment has no reference coordinates or
            its reference span is empty or inverted (``start >= end``).
    """
    chromosome = alignment.reference_name
    start = alignment.reference_start
    end = alignment.reference_end
    strand = "-" if alignment.is_reverse else "+"
    if start is None or end is None or start >= end:
        # Report the offending alignment in the exception itself; the
        # previous diagnostic referred to names that do not exist here.
        raise AssertionError(
            "alignment has no valid reference span "
            "(start=%r, end=%r, is_reverse=%r): %s"
            % (start, end, alignment.is_reverse, alignment)
        )
    return (chromosome, start, end, strand, alignment.cigar)

def get_transcripts(alignments):
    """Return the distinct "XR" tag values of the alignments as one string.

    Each alignment's "XR" tag is read with ``get_tag``; duplicates are
    removed, the remaining values are sorted, and the result is joined
    with commas.  An empty input gives an empty string.
    """
    unique_names = {read.get_tag("XR") for read in alignments}
    return ",".join(sorted(unique_names))

# Targets for which the "XR" tag values of merged alignments are combined.
_TRANSCRIPT_TARGETS = ("snRNA", "scRNA", "snoRNA", "scaRNA",
                       "mRNA", "lncRNA", "gencode", "fantomcat")


def _select_shortest_mapped(alignments):
    """Return the mapped alignments with the shortest reference span.

    If no alignment is mapped, all (unmapped) alignments are returned.
    """
    shortest_length = None
    selected = []
    for alignment in alignments:
        if alignment.is_unmapped:
            # Unmapped alignments are kept only while nothing mapped exists.
            if shortest_length is None:
                selected.append(alignment)
            continue
        length = alignment.reference_end - alignment.reference_start
        if shortest_length is None or length < shortest_length:
            # New minimum: drop everything collected so far.
            shortest_length = length
            selected = [alignment]
        elif length == shortest_length:
            selected.append(alignment)
        # Longer alignments are discarded regardless of input order.
    return selected


def _check_uniform_length(alignments):
    """Assert that every alignment carries the same "XL" tag value."""
    expected = None
    for alignment in alignments:
        length = alignment.get_tag("XL")
        assert length is not None
        if expected is None:
            expected = length
        else:
            assert expected == length


def _write_merged(output, block, target):
    """Write the first alignment of a same-location block to the output."""
    current_alignment = block[0]
    if target in _TRANSCRIPT_TARGETS:
        # Record the transcripts of all alignments collapsed into this one.
        current_alignment.set_tag("XR", get_transcripts(block))
    assert current_alignment.get_tag("XT") == target
    output.write(current_alignment)


def write_alignments(output, alignments, chromosomes, target):
    """Write the alignments of a single read for one target to the output.

    Behaviour by target:

    * ``"unmapped"`` / ``"skipped"``: exactly one alignment is expected and
      it is written unchanged.
    * ``"genome"``: only the mapped alignments with the shortest reference
      span are kept (all alignments are kept if none is mapped).
    * any other target: all alignments must carry the same "XL" tag value.

    The remaining alignments are restricted to those whose reference name is
    in ``chromosomes``.  Alignments sharing the same genome location (see
    ``get_genome_location``) are collapsed into the first one of each group;
    for transcript-based targets its "XR" tag is set to the comma-separated
    transcripts of the group.  If no alignment lies on a known chromosome, a
    single unmapped placeholder with the read's query name and an "XT" tag
    equal to ``target`` is written instead.

    Args:
        output: Object with a ``write(alignment)`` method, e.g. a
            ``pysam.AlignmentFile`` opened for writing.
        alignments: Non-empty list of alignments of one read.
        chromosomes: Container of accepted reference names.
        target: Name of the annotation target the alignments came from.

    Raises:
        AssertionError: If the number of alignments, the "XL" tags or the
            "XT" tag do not meet the expectations described above.
    """
    if target in ("unmapped", "skipped"):
        assert len(alignments) == 1
        output.write(alignments[0])
        return

    if target == "genome":
        alignments = _select_shortest_mapped(alignments)
    else:
        _check_uniform_length(alignments)

    on_chromosomes = [alignment for alignment in alignments
                      if alignment.reference_name in chromosomes]
    if on_chromosomes:
        # Sorting puts alignments with identical locations next to each
        # other so that they can be collapsed in a single pass.
        on_chromosomes.sort(key=get_genome_location)
        current = None
        block = []
        for alignment in on_chromosomes:
            location = get_genome_location(alignment)
            if location != current:
                if block:
                    _write_merged(output, block, target)
                    block = []
                current = location
            block.append(alignment)
        if block:
            _write_merged(output, block, target)
        return

    # Nothing maps to a known chromosome: emit an unmapped placeholder that
    # still records the target (and transcripts, where applicable).
    placeholder = pysam.AlignedSegment()
    placeholder.query_name = alignments[0].query_name
    placeholder.is_unmapped = True
    placeholder.set_tag("XT", target)
    if target in _TRANSCRIPT_TARGETS:
        placeholder.set_tag("XR", get_transcripts(alignments))
    output.write(placeholder)

# Driver: merge the alignments from all per-target BAM files, plus the
# unmapped and skipped FASTA lists, into a single BAM file "seqlist.bam"
# ordered like the records of "seqlist.fa".

# Chromosome names and lengths; they define the reference sequences of the
# output BAM file.
chromosomes, sizes = read_chromosome_sizes(assembly)

# Master list of sequences; the merge loop below walks it in file order.
# NOTE(review): `stream` and `output` are closed only at the very end, so
# they are not closed explicitly if an exception propagates first.
filename = "seqlist.fa"
stream = open(filename)
records = SeqIO.parse(stream, "fasta")

# Directory that is scanned for the per-target BAM files.
directory = "/osc-fs_home/mdehoon/Data/CASPARs/StartSeq"
subdirectory = "BAM"
subdirectory = os.path.join(directory, subdirectory)
filenames = os.listdir(subdirectory)

filename = "seqlist.bam"
print("Writing", filename)
output = pysam.AlignmentFile(filename, "wb", reference_names=chromosomes, reference_lengths=sizes)

# Maps each target name to an iterator over that target's alignments.
alignments = {}
for filename in filenames:
    # Only files named exactly "<target>.bam" (a single dot) are used.
    terms = filename.split(".")
    if len(terms) != 2:
        continue
    if terms[1] != 'bam':
        continue
    target = terms[0]
    path = os.path.join(subdirectory, filename)
    print("Reading", path)
    alignments[target] = pysam.AlignmentFile(path)
    # alignments[target].header may have additional keys beyond those in
    # output.header.
    # Every key present in the output header must have an equal value in
    # the input header.
    for key in output.header.keys():
        assert output.header[key] == alignments[target].header[key]

# Two additional sources that are generated from FASTA files instead of
# being read from BAM files.
alignments['unmapped'] = generate_unmapped_alignments()
alignments['skipped'] = generate_skipped_alignments()

# Prime every source by reading its first alignment. `cached` holds, per
# target, the next alignment that has been read but not yet written.
# Sources that yield nothing are reported and left out of `cached`.
# NOTE(review): such empty sources are not closed anywhere in this script.
cached = {}
for target in alignments:
    try:
        alignment = next(alignments[target])
    except StopIteration:
        print("No alignments for %s" % target)
        continue
    cached[target] = alignment

# Merge loop. Record ids must be seq_00000000, seq_00000001, ... in order.
# NOTE(review): this presumably relies on every source yielding its
# alignments grouped by query name and in seqlist.fa order -- confirm.
for number, record in enumerate(records):
    query_name = "seq_%08d" % number
    assert record.id == query_name
    # Find the source whose pending alignment belongs to this record; the
    # loop variables `target` and `alignment` are deliberately used after
    # the loop.
    for target in cached:
        alignment = cached[target]
        if alignment.query_name == query_name:
            break
    else:
        raise Exception("Failed to find alignment for %s" % record.id)
    # Collect the remaining alignments of this record from the same source.
    current_alignments = [alignment]
    for alignment in alignments[target]:
        if alignment.query_name != query_name:
            # First alignment of a later record: keep it pending.
            cached[target] = alignment
            break
        current_alignments.append(alignment)
    else:
        # The source is exhausted: close it and stop considering it.
        alignments[target].close()
        del cached[target]
    write_alignments(output, current_alignments, chromosomes, target)
output.close()
stream.close()
